{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RIGR: Resonance Invariant Graph Representation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "RIGR is introduced and discussed in our work [RIGR: Resonance Invariant Graph Representation for Molecular Property Prediction](). It is a featurizer implemented as part of Chemprop v2.1.2, designed to impose strict resonance invariance for molecular property prediction tasks. It ensures a single graph representation of different resonance structures of the same molecule, including non-equivalent resonance forms. For CLI users, RIGR is available as a choice for the multi-hot atom featurization scheme. To use RIGR, add the following argument to your training or inference script:\n",
    "   ```bash\n",
    "   --multi-hot-atom-featurizer-mode RIGR\n",
    "   ```\n",
    "In this Jupyter notebook, we show how to train and infer models using RIGR which is very similar to the generic training [example](./training.ipynb). RIGR can be easily implemented in your existing code by changing the `SimpleMoleculeMolGraphFeaturizer()` to this:\n",
    "   ```python\n",
    "   rigr_atom_featurizer = RIGRAtomFeaturizer()\n",
    "   rigr_bond_featurizer = RIGRBondFeaturizer()\n",
    "   featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer(atom_featurizer=rigr_atom_featurizer, bond_featurizer=rigr_bond_featurizer)\n",
    "   ```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chemprop/chemprop/blob/main/examples/rigr_featurizer.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install chemprop from GitHub if running in Google Colab\n",
    "import os\n",
    "\n",
    "if os.getenv(\"COLAB_RELEASE_TAG\"):\n",
    "    try:\n",
    "        import chemprop\n",
    "    except ImportError:\n",
    "        !git clone https://github.com/chemprop/chemprop.git\n",
    "        %cd chemprop\n",
    "        !pip install .\n",
    "        %cd examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "from typing import Sequence\n",
    "\n",
    "from lightning import pytorch as pl\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from rdkit import Chem\n",
    "from rdkit.Chem.rdchem import Atom, Bond, Mol\n",
    "\n",
    "from chemprop import data, featurizers, models, nn\n",
    "from chemprop.featurizers.atom import RIGRAtomFeaturizer\n",
    "from chemprop.featurizers.bond import RIGRBondFeaturizer\n",
    "from chemprop.featurizers.molecule import ChargeFeaturizer\n",
    "from chemprop.utils import make_mol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "chemprop_dir = Path.cwd().parent\n",
    "input_path = chemprop_dir / \"tests\" / \"data\" / \"regression\" / \"mol\" / \"mol.csv\" # path to your data .csv file\n",
    "num_workers = 0 # number of workers for dataloader. 0 means using main process for data loading\n",
    "smiles_column = 'smiles' # name of the column containing SMILES strings\n",
    "target_columns = ['lipo'] # list of names of the columns containing targets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>smiles</th>\n",
       "      <th>lipo</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14</td>\n",
       "      <td>3.54</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)...</td>\n",
       "      <td>-1.18</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl</td>\n",
       "      <td>3.69</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C...</td>\n",
       "      <td>3.37</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N...</td>\n",
       "      <td>3.10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>95</th>\n",
       "      <td>CC(C)N(CCCNC(=O)Nc1ccc(cc1)C(C)(C)C)C[C@H]2O[C...</td>\n",
       "      <td>2.20</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>96</th>\n",
       "      <td>CCN(CC)CCCCNc1ncc2CN(C(=O)N(Cc3cccc(NC(=O)C=C)...</td>\n",
       "      <td>2.04</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>97</th>\n",
       "      <td>CCSc1c(Cc2ccccc2C(F)(F)F)sc3N(CC(C)C)C(=O)N(C)...</td>\n",
       "      <td>4.49</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>98</th>\n",
       "      <td>COc1ccc(Cc2c(N)n[nH]c2N)cc1</td>\n",
       "      <td>0.20</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>99</th>\n",
       "      <td>CCN(CCN(C)C)S(=O)(=O)c1ccc(cc1)c2cnc(N)c(n2)C(...</td>\n",
       "      <td>2.00</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>100 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                               smiles  lipo\n",
       "0             Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14  3.54\n",
       "1   COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)... -1.18\n",
       "2              COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl  3.69\n",
       "3   OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C...  3.37\n",
       "4   Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N...  3.10\n",
       "..                                                ...   ...\n",
       "95  CC(C)N(CCCNC(=O)Nc1ccc(cc1)C(C)(C)C)C[C@H]2O[C...  2.20\n",
       "96  CCN(CC)CCCCNc1ncc2CN(C(=O)N(Cc3cccc(NC(=O)C=C)...  2.04\n",
       "97  CCSc1c(Cc2ccccc2C(F)(F)F)sc3N(CC(C)C)C(=O)N(C)...  4.49\n",
       "98                        COc1ccc(Cc2c(N)n[nH]c2N)cc1  0.20\n",
       "99  CCN(CCN(C)C)S(=O)(=O)c1ccc(cc1)c2cnc(N)c(n2)C(...  2.00\n",
       "\n",
       "[100 rows x 2 columns]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_input = pd.read_csv(input_path)\n",
    "df_input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "smis = df_input.loc[:, smiles_column].values\n",
    "ys = df_input.loc[:, target_columns].values"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Featurization and Make Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "RIGR uses only the subset of atom and bond features from Chemprop that remain invariant across different resonance forms. The tables below indicate which atom and bond features are present and absent in RIGR.\n",
    "\n",
    "### Atom Features\n",
    "\n",
    "| **Feature**            | **Description**                                                                 | **Present in RIGR?** |\n",
    "|------------------------|---------------------------------------------------------------------------------|:--------------------:|\n",
    "| Atomic&nbsp;number     | The choice for atom type denoted by atomic number                                | ☑️                   |\n",
    "| Degree                 | Number of direct neighbors of the atom                                           | ☑️                    |\n",
    "| Formal&nbsp;charge     | Integer charge assigned to the atom                                              | ☐                   |\n",
    "| Chiral&nbsp;tag        | The choices for an atom's chiral tag (See `rdkit.Chem.rdchem.ChiralType`)        | ☐                   |\n",
    "| Number&nbsp;of&nbsp;H  | Number of bonded hydrogen atoms                                                  | ☑️                   |\n",
    "| Hybridization          | Atom's hybridization type (See `rdkit.Chem.rdchem.HybridizationType`)            | ☐                   |\n",
    "| Aromaticity            | Indicates whether the atom is aromatic or not                                    | ☐                   |\n",
    "| Atomic&nbsp;mass       | The atomic mass of the atom                                                      | ☑️                   |\n",
    "\n",
    "\n",
    "### Bond Features\n",
    "\n",
    "| **Feature**           | **Description**                                                                                      | **Present in RIGR?** |\n",
    "|-----------------------|------------------------------------------------------------------------------------------------------|:--------------------:|\n",
    "| Bond&nbsp;type        | The known bond types: single, double, or triple bond                                                 | ☐                   |\n",
    "| Conjugation           | Indicates whether the bond is conjugated or not                                                     | ☐                   |\n",
    "| Ring                  | Indicates whether the bond is a part of a ring                                                      | ☑️                    |\n",
    "| Stereochemistry       | Stores the known bond stereochemistries (See [BondStereo](https://www.rdkit.org/docs/source/rdkit.Chem.rdchem.html#rdkit.Chem.rdchem.BondStereo.values)) | ☐                    |"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The return type of make_split_indices has changed in v2.1 - see help(make_split_indices)\n"
     ]
    }
   ],
   "source": [
    "mols = [make_mol(smi, add_h=True, keep_h=True) for smi in smis]\n",
    "\n",
    "charge_featurizer = ChargeFeaturizer()\n",
    "x_ds = [charge_featurizer(mol) for mol in mols]\n",
    "\n",
    "all_data = [data.MoleculeDatapoint(mol, name=smi, y=y, x_d=x_d) for mol, smi, y, x_d in zip(mols, smis, ys, x_ds)]\n",
    "train_indices, val_indices, test_indices = data.make_split_indices(mols, \"random\", (0.8, 0.1, 0.1))\n",
    "train_data, val_data, test_data = data.split_data_by_indices(\n",
    "    all_data, train_indices, val_indices, test_indices\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "rigr_atom_featurizer = RIGRAtomFeaturizer()\n",
    "rigr_bond_featurizer = RIGRBondFeaturizer()\n",
    "featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer(atom_featurizer=rigr_atom_featurizer, bond_featurizer=rigr_bond_featurizer)\n",
    "\n",
    "train_dset = data.MoleculeDataset(train_data[0], featurizer)\n",
    "scaler = train_dset.normalize_targets()\n",
    "\n",
    "val_dset = data.MoleculeDataset(val_data[0], featurizer)\n",
    "val_dset.normalize_targets(scaler)\n",
    "\n",
    "test_dset = data.MoleculeDataset(test_data[0], featurizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader = data.build_dataloader(train_dset, num_workers=num_workers)\n",
    "val_loader = data.build_dataloader(val_dset, num_workers=num_workers, shuffle=False)\n",
    "test_loader = data.build_dataloader(test_dset, num_workers=num_workers, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MPNN(\n",
       "  (message_passing): BondMessagePassing(\n",
       "    (W_i): Linear(in_features=54, out_features=300, bias=False)\n",
       "    (W_h): Linear(in_features=300, out_features=300, bias=False)\n",
       "    (W_o): Linear(in_features=352, out_features=300, bias=True)\n",
       "    (dropout): Dropout(p=0.0, inplace=False)\n",
       "    (tau): ReLU()\n",
       "    (V_d_transform): Identity()\n",
       "    (graph_transform): Identity()\n",
       "  )\n",
       "  (agg): MeanAggregation()\n",
       "  (bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (predictor): RegressionFFN(\n",
       "    (ffn): MLP(\n",
       "      (0): Sequential(\n",
       "        (0): Linear(in_features=301, out_features=300, bias=True)\n",
       "      )\n",
       "      (1): Sequential(\n",
       "        (0): ReLU()\n",
       "        (1): Dropout(p=0.0, inplace=False)\n",
       "        (2): Linear(in_features=300, out_features=1, bias=True)\n",
       "      )\n",
       "    )\n",
       "    (criterion): MSE(task_weights=[[1.0]])\n",
       "    (output_transform): UnscaleTransform()\n",
       "  )\n",
       "  (X_d_transform): Identity()\n",
       "  (metrics): ModuleList(\n",
       "    (0): RMSE(task_weights=[[1.0]])\n",
       "    (1): MAE(task_weights=[[1.0]])\n",
       "    (2): MSE(task_weights=[[1.0]])\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mp = nn.BondMessagePassing(\n",
    "    d_v=featurizer.atom_fdim,\n",
    "    d_e=featurizer.bond_fdim,\n",
    ")\n",
    "agg = nn.MeanAggregation()\n",
    "output_transform = nn.UnscaleTransform.from_standard_scaler(scaler)\n",
    "ffn = nn.RegressionFFN(\n",
    "    input_dim=mp.output_dim + train_dset.d_xd,\n",
    "    output_transform=output_transform,\n",
    ")\n",
    "batch_norm = True\n",
    "metric_list = [nn.metrics.RMSE(), nn.metrics.MAE()] # Only the first metric is used for training and early stopping\n",
    "mpnn = models.MPNN(mp, agg, ffn, batch_norm, metric_list)\n",
    "mpnn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/akshatz/anaconda3/envs/chemprop/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/akshatz/anaconda3/envs/chemprop/lib/python3.12 ...\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    }
   ],
   "source": [
    "trainer = pl.Trainer(\n",
    "    logger=False,\n",
    "    enable_checkpointing=True, # Use `True` if you want to save model checkpoints. The checkpoints will be saved in the `checkpoints` folder.\n",
    "    enable_progress_bar=True,\n",
    "    accelerator=\"auto\",\n",
    "    devices=1,\n",
    "    max_epochs=20, # number of epochs to train for\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Start Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
      "/home/akshatz/anaconda3/envs/chemprop/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /home/akshatz/chemprop/examples/checkpoints exists and is not empty.\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n",
      "Loading `train_dataloader` to estimate number of stepping batches.\n",
      "/home/akshatz/anaconda3/envs/chemprop/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.\n",
      "\n",
      "  | Name            | Type               | Params | Mode \n",
      "---------------------------------------------------------------\n",
      "0 | message_passing | BondMessagePassing | 212 K  | train\n",
      "1 | agg             | MeanAggregation    | 0      | train\n",
      "2 | bn              | BatchNorm1d        | 600    | train\n",
      "3 | predictor       | RegressionFFN      | 90.9 K | train\n",
      "4 | X_d_transform   | Identity           | 0      | train\n",
      "5 | metrics         | ModuleList         | 0      | train\n",
      "---------------------------------------------------------------\n",
      "303 K     Trainable params\n",
      "0         Non-trainable params\n",
      "303 K     Total params\n",
      "1.214     Total estimated model params size (MB)\n",
      "25        Modules in train mode\n",
      "0         Modules in eval mode\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sanity Checking DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/akshatz/anaconda3/envs/chemprop/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 19: 100%|██████████| 2/2 [00:00<00:00, 39.68it/s, train_loss_step=0.302, val_loss=0.820, train_loss_epoch=0.164]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=20` reached.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 19: 100%|██████████| 2/2 [00:00<00:00, 34.73it/s, train_loss_step=0.302, val_loss=0.820, train_loss_epoch=0.164]\n"
     ]
    }
   ],
   "source": [
    "trainer.fit(mpnn, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n",
      "/home/akshatz/anaconda3/envs/chemprop/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 299.42it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test/mae          </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.6620886325836182     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test/rmse         </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.9359426498413086     </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test/mae         \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.6620886325836182    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test/rmse        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.9359426498413086    \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "results = trainer.test(mpnn, test_loader)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "chemprop",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
